JAX is an open-source library designed for high-performance numerical computing and machine learning research. It provides automatic differentiation, GPU/TPU acceleration, and just-in-time compilation to speed up code execution. ROCm, AMD's open-source software platform for GPU computing, lets JAX run its computations on AMD GPUs. A ROCm-enabled JAX container is a preconfigured, portable environment that ships JAX built for AMD GPUs. With this container, developers and researchers can run JAX-based machine learning workloads on AMD GPUs without installing and configuring JAX and its ROCm software dependencies by hand, which streamlines their workflow.
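As a quick illustration of automatic differentiation and just-in-time compilation, the sketch below differentiates a small loss function with jax.grad and compiles the resulting gradient with jax.jit. It assumes JAX is already installed (for example, inside the container used below); the loss function itself is a made-up example, and the code runs on whatever backend JAX selects, including CPU.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Sum of squared outputs of the linear map x @ w (illustrative loss)
    return jnp.sum((x @ w) ** 2)


# jax.grad builds the gradient with respect to the first argument (w);
# jax.jit compiles it with XLA for the default backend (an AMD GPU under ROCm)
grad_loss = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)  # identity matrix, so x @ w == w and the gradient is 2 * w
g = grad_loss(w, x)
print(g)  # [2. 4.]
```

The first call to grad_loss triggers compilation; later calls with arrays of the same shape and dtype reuse the compiled code.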
In this article, you will download and run a ROCm-supported JAX container, then install JAX with pip for the ROCm compute platform.
In this section, you will download and run a ROCm-supported JAX container and check GPU availability from inside it.
Pull the ROCm-supported JAX container image.
$ docker pull rocm/jax:latest
Run a temporary Docker container.
$ docker run --rm -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 8G rocm/jax:latest
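The above command runs a temporary container with access to the GPU devices that ROCm-supported JAX workloads need. --rm removes the container when it exits, -it attaches an interactive terminal, --device=/dev/kfd and --device=/dev/dri expose the ROCm compute interface and the GPU render nodes, --security-opt seccomp=unconfined disables Docker's default seccomp profile, and --shm-size 8G raises the container's shared memory limit to 8 GB.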
Verify GPU availability from the container.
$ rocm-smi
$ python3 -c 'import jax; print(jax.devices())'
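rocm-smi lists the AMD GPUs visible to the container along with their current status, such as temperature, power draw, memory use, and utilization. The Python command prints the devices JAX detects; each AMD GPU should appear in its output.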
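For more detail than the one-line check, the following script prints each device's ID, platform, and kind, plus the backend JAX selected by default. It is a sketch that uses only standard JAX calls; the exact device_kind strings depend on your GPU model. Save it to a file in the container (any name works) and run it with python3.

```python
import jax

# Enumerate every device JAX can see on the default backend
devices = jax.devices()
for d in devices:
    # Device ID, platform name, and hardware description reported by the runtime
    print(d.id, d.platform, d.device_kind)

# The backend JAX uses when no device is specified explicitly
print("Default backend:", jax.default_backend())
```

If the default backend prints as cpu, JAX did not find a usable GPU inside the container.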
Exit and destroy the temporary container.
$ exit
In this section, you will install jaxlib, the JAX ROCm plugin, and JAX on the host machine using pip, then check GPU availability.
Fetch the Python version.
$ python3 -V
Fetch the ROCm version.
$ amd-smi version
Navigate to the JAX ROCm GitHub fork releases page.
In the latest release notes, find the installation commands that match both the Python version and the ROCm version you retrieved above.
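If the release notes list wheel files rather than only pip commands, the Python version appears in each filename as a CPython tag such as cp312. The snippet below, which uses only the standard library, prints that tag for the current interpreter; it assumes the release assets follow the standard wheel naming convention.

```python
import sys

# CPython wheels encode the interpreter version as "cp" + major + minor,
# for example Python 3.12 -> "cp312"
python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
print("Look for wheels tagged:", python_tag)
```

Run it with the same python3 you will use for pip so the tag matches the interpreter that receives the packages.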
Copy the jaxlib and JAX ROCm plugin installation commands from the release notes and run them in your terminal.
jaxlib: the compiled backend library for JAX. It bridges JAX and the hardware it runs on, such as CPUs, GPUs, or TPUs, by providing the XLA compiler and runtime; JAX cannot run without it.
Install the JAX Python package.
$ python3 -m pip install jax
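Because jax depends on a compatible jaxlib, it can help to confirm which versions pip actually installed before testing the GPU. This standard-library sketch reads the installed package metadata. The plugin's package name depends on the ROCm version listed in the release notes, so it is left out here; you can append it to the list yourself.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if not installed}."""
    result = {}
    for pkg in packages:
        try:
            result[pkg] = version(pkg)
        except PackageNotFoundError:
            result[pkg] = None
    return result


versions = installed_versions(["jax", "jaxlib"])
for pkg, ver in versions.items():
    print(f"{pkg}: {ver or 'not installed'}")
```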
Verify GPU availability.
$ python3 -c 'import jax; print(jax.devices())'
The output should list each AMD GPU that JAX detects, along with its device ID.
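Listing devices confirms that JAX sees the GPUs; running a compiled computation confirms that it can use them. The sketch below jit-compiles a matrix multiplication on random inputs and reports where the result is stored. The matrix size is arbitrary. On a working ROCm setup the result should live on a GPU device; if the default backend prints as cpu, JAX did not load the ROCm plugin and fell back to the CPU.

```python
import jax
import jax.numpy as jnp


@jax.jit
def matmul(a, b):
    # Compiled by XLA for the default backend on the first call
    return a @ b


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (512, 512))
b = jax.random.normal(key_b, (512, 512))

# JAX dispatches work asynchronously; block_until_ready waits for the result
c = matmul(a, b).block_until_ready()

print("Default backend:", jax.default_backend())
print("Result stored on:", c.devices())
```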
In this article, you ran a ROCm-supported JAX container with access to the host's AMD GPUs and installed JAX with pip for the ROCm compute platform.